{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "## expose only GPU 0 to CUDA; this is set before torch is imported below\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0' ## please run on a machine with a GPU\n",
    "import sys\n",
    "## make packages in the current working directory importable\n",
    "sys.path.append('.')\n",
    "import torch\n",
    "## allow TF32 math in CUDA matmuls and in cuDNN\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "## ask cuDNN to use deterministic algorithms\n",
    "torch.backends.cudnn.deterministic = True\n",
    "from diffuser.guides.policies import Policy\n",
    "import diffuser.datasets as datasets\n",
    "import diffuser.utils as utils\n",
    "from diffuser.models import GaussianDiffusionPB\n",
    "from diffuser.guides.policies_compose import PolicyCompose\n",
    "from datetime import datetime\n",
    "import os.path as osp\n",
    "from tqdm import tqdm\n",
    "from diffuser.utils.jupyter_utils import suppress_stdout\n",
    "\n",
    "## You need to download 'maze2d-static1-base' from the OneDrive link in README.md before running this notebook\n",
    "class Parser(utils.Parser):\n",
    "    # Path of the experiment config; presumably read by utils.Parser.parse_args -- confirm there.\n",
    "    config: str = \"config/rm2d/rSmaze_nw6_hExt05_exp.py\"\n",
    "\n",
    "#---------------------------------- setup ----------------------------------#\n",
    "\n",
    "## training args\n",
    "args_train = Parser().parse_args('diffusion', from_jupyter=True)\n",
    "## planning args, parsed with the same Parser class\n",
    "args = Parser().parse_args('plan', from_jupyter=True)\n",
    "\n",
    "## no save path is set for this notebook session\n",
    "args.savepath = None # osp.join(args.savepath, sub_dir)\n",
    "args.load_unseen_maze = True\n",
    "\n",
    "## load dataset here, dataset is a string: name of the env\n",
    "print('args.dataset', type(args.dataset), args.dataset)\n",
    "print('args.dataset_eval', type(args.dataset_eval), args.dataset_eval)\n",
    "## falls back to False when the training dataset config has no such key\n",
    "use_normed_wallLoc = args_train.dataset_config.get('use_normed_wallLoc', False)\n",
    "\n",
    "## use the trained env or eval env\n",
    "## NOTE(review): this hard-coded flag is what load_environment receives below;\n",
    "## args.load_unseen_maze set above is not read in this cell.\n",
    "load_unseen_maze = True # args.load_unseen_maze # True False\n",
    "with suppress_stdout():\n",
    "    #---------------------------------- loading ----------------------------------#\n",
    "    from pb_diff_envs.environment.static.maze2d_rand_wgrp_43 import RandRectangleGroup_43\n",
    "    ## bare annotation: a type hint only, it binds no value\n",
    "    train_env_list: RandRectangleGroup_43\n",
    "    ## NOTE: despite the name, this is loaded with is_eval=load_unseen_maze (True here)\n",
    "    train_env_list = datasets.load_environment(args.dataset, is_eval=load_unseen_maze)\n",
    "    train_normalizer = utils.load_datasetNormalizer(train_env_list.dataset_url, args_train, train_env_list)\n",
    "\n",
    "    ## the env group is handed to the model loader through ld_config\n",
    "    ld_config = dict(env_instance=train_env_list) \n",
    "    diffusion_experiment = utils.load_potential_diffusion_model(args.logbase, args.dataset, \\\n",
    "                args_train.exp_name, epoch=args.diffusion_epoch, ld_config=ld_config)\n",
    "    ## the model used below is the experiment's `ema` attribute (presumably the EMA weights -- confirm)\n",
    "    diffusion = diffusion_experiment.ema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffuser.utils.rm2d_render import RandStaticMazeRenderer\n",
    "import imageio\n",
    "np.set_printoptions(precision=3)\n",
    "## define an env's wall locations; wloc_source selects where they come from:\n",
    "##   'from_env'   -> reuse the walls of an existing env in train_env_list\n",
    "##   'user_input' -> use the coordinates typed in below\n",
    "# wloc_source = 'from_env'\n",
    "wloc_source = 'user_input'\n",
    "if wloc_source == 'from_env':\n",
    "    env_id = 17\n",
    "    wloc = train_env_list.wallLoc_list[env_id]\n",
    "    print(f'wall locations of env {env_id}:')\n",
    "    print(wloc)\n",
    "elif wloc_source == 'user_input':\n",
    "    env_id = 5 # placeholder: env slot passed to create_env_by_pos below\n",
    "    wloc = np.array([\n",
    "        [0.983, 1.844],\n",
    "        [1.444, 3.904],\n",
    "        [2.32,  1.612],\n",
    "        [3.495, 0.713],\n",
    "        [3.9,   4.307],\n",
    "        [4.107, 2.546],])\n",
    "\n",
    "    ## wall size must be fixed, so the existing extents of this slot are reused\n",
    "    train_env_list.create_env_by_pos(env_id, wloc, train_env_list.hExt_list[env_id])\n",
    "else:\n",
    "    ## fail here instead of leaving wloc / env_id undefined (or stale from an\n",
    "    ## earlier run of this cell) for the cells below\n",
    "    raise ValueError(f\"unknown wloc_source {wloc_source!r}; expected 'from_env' or 'user_input'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "diffusion: GaussianDiffusionPB\n",
    "bs = 10 ## number of copies of the start/goal pair planned in one batch\n",
    "policy = Policy(diffusion, train_normalizer, use_ddim=True)\n",
    "\n",
    "\n",
    "## choose a start and goal position (one row each, repeated bs times)\n",
    "start = np.array( [ [ 0.4, 4.7 ] ], dtype=np.float32 ).repeat( bs, 0 )\n",
    "goal = np.array( [ [4.8, 1.2] ], dtype=np.float32  ).repeat( bs, 0 )\n",
    "\n",
    "print( start.shape )\n",
    "wloc_np = wloc # wloc in numpy\n",
    "\n",
    "## flatten the wall locations into a single row and repeat that row bs times\n",
    "wloc_tensor = utils.to_torch(wloc_np).reshape(1, -1).repeat( (bs, 1) )\n",
    "print( 'wloc_np: ', wloc_np.shape )\n",
    "\n",
    "## conditioning dict; the keys look like trajectory step indices\n",
    "## (47 is presumably horizon - 1 -- TODO confirm against the config)\n",
    "cond = {0: start,\n",
    "        47: goal}\n",
    "\n",
    "samples = policy(cond, batch_size=-1, wall_locations=wloc_tensor, use_normed_wallLoc=use_normed_wallLoc, return_diffusion=True)\n",
    "\n",
    "unnm_traj = samples[1].observations\n",
    "print(f'unnm_traj: {unnm_traj.shape}')\n",
    "\n",
    "\n",
    "root_dir = './visualization/'\n",
    "## the picture is written to disk and read back below, so make sure the\n",
    "## folder exists (it is not guaranteed to on a fresh checkout)\n",
    "os.makedirs(root_dir, exist_ok=True)\n",
    "print('traj shape:', unnm_traj.shape)\n",
    "rd = RandStaticMazeRenderer(train_env_list)\n",
    "tmp_path = os.path.join(root_dir, 'plot_rm2d_base.png')\n",
    "## the return value is not needed: the saved file is what gets displayed\n",
    "rd.composite( tmp_path, unnm_traj,  np.array([env_id]))\n",
    "img = imageio.imread(tmp_path)\n",
    "utils.plt_img(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Another way of Rendering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import imageio\n",
    "from pb_diff_envs.environment.static.rand_maze2d_renderer import RandMaze2DRenderer\n",
    "\n",
    "\n",
    "## which of the sampled trajectories to draw\n",
    "path_id = 9\n",
    "n_walls = len(wloc_np)\n",
    "\n",
    "## hExt_list[0, 0] is reused for every wall (presumably one wall's half extent)\n",
    "hExt = train_env_list.hExt_list[0, 0].tolist() # [0.5, 0.5]\n",
    "hExts = np.array([hExt for _ in range(n_walls)])\n",
    "\n",
    "lemp_rd = RandMaze2DRenderer(num_walls_c=[n_walls, 0], fig_dpi=300)\n",
    "img = lemp_rd.renders(\n",
    "    unnm_traj[path_id],\n",
    "    train_env_list.maze_size,\n",
    "    wloc_np,\n",
    "    hExts,\n",
    ")\n",
    "print(f'resol: {img.shape}')\n",
    "## target of the optional save below\n",
    "tmp_path = f'./plot_rm2d_base.png'\n",
    "# imageio.imsave( tmp_path, img )\n",
    "utils.plt_img(img, dpi=120)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a denoising video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pb_diff_envs.environment.static.rand_maze2d_renderer import RandMaze2DRenderer\n",
    "## env whose walls are drawn behind every frame\n",
    "env = train_env_list.model_list[env_id]\n",
    "\n",
    "traj_idx = 0 # which traj\n",
    "\n",
    "x0_perstep = samples[3] # shape: B, n_steps, Horizon, Dim\n",
    "\n",
    "n_steps = len(x0_perstep[0]) ## number of denoising steps + 1\n",
    "trajs_ps = [] ## per-step trajs\n",
    "imgs_ps = [] ## one rendered frame per step\n",
    "\n",
    "mz_renderer = RandMaze2DRenderer(num_walls_c=[30,10], fig_dpi=200)\n",
    "mz_renderer.up_right = 'empty'\n",
    "mz_renderer.config['no_draw_traj_line'] = False\n",
    "mz_renderer.config['no_draw_col_info'] = True # False\n",
    "# mz_renderer.num_walls_c = [0, 30]\n",
    "\n",
    "## the walls do not change between frames, so look them up once\n",
    "env_wlocs = env.wall_locations\n",
    "env_hExts = env.wall_hExts\n",
    "\n",
    "## step 0 is skipped (presumably the initial noise -- confirm)\n",
    "for i_st in range(1, n_steps):\n",
    "    traj_ps = x0_perstep[traj_idx][i_st]\n",
    "    trajs_ps.append(traj_ps)\n",
    "    img = mz_renderer.renders( traj_ps , train_env_list.maze_size, env_wlocs, env_hExts )\n",
    "    imgs_ps.append(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## show the video\n",
    "rootdir = './visualization/'\n",
    "## the mp4 is written into this folder, so make sure it exists first\n",
    "os.makedirs(rootdir, exist_ok=True)\n",
    "print('rootdir:', rootdir)\n",
    "# import diffuser.utils.video as duv; reload(duv)\n",
    "from diffuser.utils.video import save_images_to_mp4, read_mp4_to_numpy\n",
    "\n",
    "mp4_name = os.path.join(rootdir, f'base_{env_id}.mp4') # {utils.get_time()}\n",
    "## one frame rate shared by the writer and the inline player\n",
    "video_fps = 15\n",
    "## Save to video\n",
    "save_images_to_mp4(imgs_ps, mp4_name, fps=video_fps, st_sec=1, end_sec=1)\n",
    "## read the saved file back and display it inline\n",
    "import mediapy as media\n",
    "from mediapy import show_video, read_video\n",
    "show_video(read_mp4_to_numpy(mp4_name), fps=video_fps, width=400)\n",
    "print(f'mp4_name: {mp4_name}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
